Conversation

jwjohns
Contributor

@jwjohns jwjohns commented Aug 25, 2025

This PR adds initial support for the Nemotron-H hybrid architecture used by NVIDIA Nemotron Nano V2
models. Nemotron-H combines Mamba2 state-space model layers with selective transformer attention layers
for efficient inference.

Note: This PR will remain in DRAFT status until full inference is working.

Architecture Details

  • Total layers: 56 (27 SSM + 25 MLP + 4 attention at positions [14, 21, 30, 39])
  • Hybrid pattern: Alternating Mamba2 SSM layers with periodic attention layers
  • Key components: SSM conv1d, A/D parameters, normalization, x/z gating with SiLU

Current Issue

SSM conv assertion failure during graph building

GGML_ASSERT(sx->ne[1] == d_inner) fails in ggml_ssm_conv: the second dimension of the tensor sx doesn't match the expected d_inner value.

What Works

  1. Model loading works - the tensor dimension issue is resolved
  2. All tensors loaded correctly - no more "wrong shape" errors

Addresses #15409 - Request for Nemotron-H architecture support

jwjohns and others added 14 commits August 23, 2025 15:20
- Add custom cache initialization filters for LLM_ARCH_NEMOTRON_H
- Attention cache only allocated for layers 14, 21, 30, 39 (attention layers)
- Recurrent cache only allocated for SSM layers using is_recurrent()
- Reduces KV cache memory usage from 264MB (29 layers) to 64MB (4 layers)
- Implements proper Mamba2-style SSM with x/z gating and SiLU activation
- Resolves infinite hang issue during token generation
- fixed A/D tensor shapes from [128,1,1,1] to [1,128]
- fixed conv1d dimensions to use actual 12288 not 17728
- fixed ssm_norm and ssm_out tensor sizes to use 10240
- fixed layer_types array type from uint8 to int32
- fixed gguf numpy array serialization
- added missing template instantiations
- model now loads to tensor validation stage
- created working 18GB gguf file
  that tries both orientations
@gabe-l-hart
Collaborator

I hit that same assertion. Here's what I see in my debugger:

   236 	// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp
(lldb) up 1
frame #4: 0x000000010110db30 libggml-base.dylib`ggml_ssm_conv(ctx=0x00006000022add40, sx=0x000000016015d3e0, c=0x00000001010585e0) at ggml.c:5033:5
   5030	
   5031	    // TODO: maybe support other strides than 1?
   5032	    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
-> 5033	    GGML_ASSERT(sx->ne[1] == d_inner);
   5034	    GGML_ASSERT(n_t >= 0);
   5035	
   5036	    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
(lldb) p sx->ne[1]
(int64_t) 17728
(lldb) p d_inner
(const int64_t) 12288

so the sx->ne[1] is the full computed conv shape (versus the overwritten one)
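
For context, ggml_ssm_conv takes its expected channel count from the conv weight's second dimension (c->ne[1], i.e. 12288 here, which in the graph-code naming is d_inner + 2*n_group*d_state), so the sx view built in the graph has to be narrowed to exactly that many rows before the call. A hypothetical sketch of the kind of slice that implies, assuming a Mamba2-style [z | xBC | dt] layout of the in_proj output; the names zxBCdt, ctx0 and n_tokens are illustrative, not taken from this PR:

```cpp
// Hypothetical sketch, not the actual patch: take only the conv channels out of the
// projected tensor before it is fed (via the usual transpose/concat with the conv state)
// into ggml_ssm_conv. Assumed layout along ne[0]: [z | xBC | dt].
const int64_t d_conv_channels = d_inner + 2*n_group*d_state;        // 10240 + 2*8*128 = 12288
ggml_tensor * xBC = ggml_view_2d(ctx0, zxBCdt,
        d_conv_channels, n_tokens,                                  // 12288 values per token
        zxBCdt->nb[1],                                              // keep the original row stride
        d_inner*ggml_element_size(zxBCdt));                         // skip the leading z block
```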

@github-actions github-actions bot added the examples, python (python script changes), and server labels on Aug 25, 2025
@jwjohns
Contributor Author

jwjohns commented Aug 25, 2025

same for me

/Development/Nemotron/llama.cpp/ggml/src/ggml.c-5031-    // TODO: maybe support other strides than 1?
/Development/Nemotron/llama.cpp/ggml/src/ggml.c-5032-    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
/Development/Nemotron/llama.cpp/ggml/src/ggml.c:5033:    GGML_ASSERT(sx->ne[1] == d_inner);
/Development/Nemotron/llama.cpp/ggml/src/ggml.c-5034-    GGML_ASSERT(n_t >= 0);
/Development/Nemotron/llama.cpp/ggml/src/ggml.c-5035-
/Development/Nemotron/llama.cpp/ggml/src/ggml.c-5036-    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);

@gabe-l-hart
Collaborator

The sx->ne[0] dimension is coming from conv_x here

@isaac-mcfadyen
Contributor

Is this a duplicate of #15507?

@gabe-l-hart
Collaborator

@isaac-mcfadyen It is a duplicate. We both started working on it independently, so we have both open for reference until we converge on a single implementation since neither is working yet.

@jwjohns
Contributor Author

jwjohns commented Aug 26, 2025

Is this a duplicate of #15507?

Apologies. I did wait rather than just create the duplicate immediately; granted, I should have set the PR against his fork and branch.

{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },

// Nemotron-H specific
{ LLM_KV_LAYER_TYPES, "%s.layer_types" },
Collaborator

@gabe-l-hart gabe-l-hart Aug 26, 2025

I think we can get away with not adding this new hparam. This is similar to a piece of feedback I got during #13550 (it's a looong PR, but it's in there somewhere). I had introduced a new array hparam similar to this one (mine was a bool), but @compilade pointed out that we could extract the same information by setting n_head_kv to an array value during conversion and then reading it per-layer (here). In this case, we can leverage n_ff in the same way so that the layer types are determined as:

  1. n_head_kv == 0 && n_ff == 0 => recurrent
  2. n_head_kv == 0 && n_ff > 0 => MLP
  3. n_head_kv > 0 && n_ff == 0 => attention
  4. n_head_kv > 0 && n_ff > 0 => INVALID (or maybe valid for a future architecture??); see the sketch below
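
A minimal sketch of that per-layer classification on the loading side, assuming the conversion script writes n_head_kv and n_ff as per-layer arrays as described above (llama_hparams already exposes them per layer via n_head_kv(il) and n_ff(il)):

```cpp
// Sketch only: classify each Nemotron-H layer from its per-layer KV head count and FFN size.
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
    const bool has_attn = hparams.n_head_kv(il) > 0;
    const bool has_ffn  = hparams.n_ff(il)      > 0;

    if (!has_attn && !has_ffn) {
        // recurrent (Mamba2 SSM) layer
    } else if (!has_attn &&  has_ffn) {
        // MLP layer
    } else if ( has_attn && !has_ffn) {
        // attention layer
    } else {
        GGML_ABORT("Nemotron-H: layer %u has both attention and FFN", il);
    }
}
```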

const int64_t n_rs = mctx->get_n_rs();

if (s_copy) {
// Check if buffer was allocated - skip if not
Collaborator

I don't think anything about the architecture should require these changes to the core recurrent graph structures. Can you clarify what condition led you to adding these conditional checks?

Contributor Author

@gabe-l-hart You're right. This was a workaround for a crash where the graph
was trying to access recurrent state for all 56 layers, but I think Nemotron-H only allocates
recurrent state for the 27 SSM layers (not attention/MLP layers).

still learning.

}

template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
template bool llama_model_loader::get_arr<std::vector<unsigned char>>(enum llm_kv kid, std::vector<unsigned char> & result, bool required);
Collaborator

If we get rid of the new hparam, these template specializations won't be needed anymore (but nice job finding them, it took me a loong time to find them myself, and I have to re-find them every time)

completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
fprintf(stderr, "[DETOKENIZE] Token ID: %d -> Text: '%s' (length: %zu)\n", result.tok, result.text_to_send.c_str(), result.text_to_send.length());
Collaborator

I love adding useful logs like this! This project has its own set of logging macros to use. In the core of the project, you've got LLAMA_LOG_* (defined here). In the server tool, there are three sets defined here: SRV_* for server-level logs, SLT_* for slot-level logs, and QUE_* for queue-level logs.
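
As an example, the detokenize line above could become a slot-level debug log along these lines (a sketch; it assumes a server_slot named slot is in scope at that call site):

```cpp
// Sketch: same information as the fprintf above, routed through the slot-level debug macro.
SLT_DBG(slot, "detokenize: token id %d -> '%s' (length: %zu)\n",
        result.tok, result.text_to_send.c_str(), result.text_to_send.length());
```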

for (; begin != end; ++begin) {
ret += common_token_to_piece(ctx, *begin);
std::string piece = common_token_to_piece(ctx, *begin);
fprintf(stderr, "[DEBUG] Token ID: %d -> Piece: '%s' (length: %zu)\n", *begin, piece.c_str(), piece.length());
Collaborator

Same comment about logging for these ones. They're definitely useful, so it would be great to make these proper logs!


// Try to load layer schedule from GGUF: %s.layer_types (0=SSM,1=ATTN,2=FFN)
std::vector<int32_t> layer_types;
const bool has_schedule = ml.get_arr(LLM_KV_LAYER_TYPES, layer_types, false) && layer_types.size() == hparams.n_layer;
Collaborator

If we leverage n_ff for this, this parsing gets a lot easier!
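
Concretely, the arch-specific block would then not need to parse anything new; a sketch of the simplified read, assuming the loader's get_key_or_arr scalar-or-array fallback applies to these keys as it does for other models with per-layer head counts:

```cpp
// Sketch: read n_ff and n_head_kv per layer instead of a dedicated layer_types array.
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,     hparams.n_ff_arr,      hparams.n_layer, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false);
```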

ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

// Nemotron-H attention parameters (fixed per public config)
hparams.n_embd_head_k = 128; // attention head size
Collaborator

In a final version of this PR, it will be best to avoid these hard-coded numbers so that the architecture remains independent of the specific model instance we're building it for. Of course, when getting it all working these are fine and a good way to isolate variables between conversion and loading.
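
For the final version, the head sizes could come from the standard GGUF attention keys instead; a sketch, assuming the conversion script emits key_length/value_length for this model:

```cpp
// Sketch: read the attention head sizes from the GGUF metadata rather than hard-coding 128.
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH,   hparams.n_embd_head_k, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
```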

const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
// Use actual dimension from model: 22656 instead of calculated 22608
const int64_t d_in_proj = 22656; // 2*d_inner + 2*n_group*d_state + n_head + 48;
Collaborator

On my PR, I avoided the need to hard-code this by setting d_inner as mamba_num_heads (128) * mamba_head_dim (80), setting n_group to mamba_num_groups (8), setting d_state to mamba_state_dim (128), and setting n_head to head_dim (NOTE: not mamba_head_dim!). This works out to 2*128*80 + 2*8*128 + 128 == 22656.
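
Spelled out as code, the derivation from that comment (the literal values are the Nemotron Nano V2 config values quoted above and are shown only to make the arithmetic visible, not to be hard-coded):

```cpp
// Sketch of the derivation: no hard-coded 22656.
const int64_t d_inner   = 128 * 80;   // mamba_num_heads * mamba_head_dim = 10240
const int64_t n_group   = 8;          // mamba_num_groups
const int64_t d_state   = 128;        // mamba_state_dim
const int64_t n_head    = 128;        // head_dim (not mamba_head_dim)
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;   // 20480 + 2048 + 128 = 22656
```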

/* unified */ cparams.kv_unified,
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_NEMOTRON_H) ?
Collaborator

I'm glad you found this! Since we now have n == 2 models that need this pattern, I've tried to make it a little cleaner by having a section of if/else cases to define architecture-specific filter lambdas (here)
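
Roughly, that ends up looking like the following (a sketch, not the exact code from the linked PR; is_recurrent() is the per-layer helper already mentioned in the commit notes):

```cpp
// Sketch of architecture-specific cache filters: attention cache only for the attention
// layers, recurrent cache only for the SSM layers.
llama_memory_hybrid::layer_filter_cb filter_attn = [&](int32_t il) {
    return hparams.n_head_kv(il) > 0;
};
llama_memory_hybrid::layer_filter_cb filter_recr = [&](int32_t il) {
    return hparams.is_recurrent(il);
};
```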

Collaborator

@gabe-l-hart gabe-l-hart Aug 26, 2025

aaaand, it looks like my version is broken somehow! EDIT: fixed (sloppy typo)

@jwjohns
Contributor Author

jwjohns commented Aug 26, 2025

I really appreciate the feedback!

# for security reason, we don't allow loading remote code by default
# if a model need remote code, we will fallback to config.json
config = AutoConfig.from_pretrained(dir_model, trust_remote_code=False).to_dict()
config = AutoConfig.from_pretrained(dir_model, trust_remote_code=True).to_dict()
Contributor Author

I got tired of typing it. Temporary.

print(f"DEBUG: Failed metadata key type: {type(val)}")
print(f"DEBUG: Failed metadata value: {val}")
print(f"DEBUG: Caller info available in stack trace")
raise ValueError(f"Invalid GGUF metadata array, expecting sequence but got {type(val)}: {val}")
Contributor Author

More debug, didn't mean to commit. Will clean up.

@github-actions github-actions bot added the ggml (changes relating to the ggml tensor library for machine learning) label on Aug 27, 2025
@gabe-l-hart
Collaborator

Thanks for the great work here! Closing this now that we've consolidated in #15507

@jwjohns
Contributor Author

jwjohns commented Aug 29, 2025

@gabe-l-hart really appreciate the help!!
